# -*- coding: utf-8 -*-
"""
Created on Mon Dec 20 10:54:31 2021

@author: wulong
"""
import numpy as np
from casadi import *
from Medium_model import* 

# In[1] Medium MPC
def formulate_opt_m(initial_state, state_guess, state_bnds_lo, state_bnds_up,
                    slow_state_bnds,
                    fast_state_guess, fast_state_bnds_lo, fast_state_bnds_up,
                    fast_state_slack_lo, fast_state_slack_up,
                    slow_input_bnds,
                    input_guess, input_bnds_lo, input_bnds_up,
                    fast_input_guess, fast_input_bnds_lo, fast_input_bnds_up,
                    input_bin_bnds,
                    distb_bnds,
                    output_guess,
                    output_setpnts, medium_state_setpnts, fast_state_setpnts,
                    medium_input_setpnts, fast_input_setpnts,
                    alpha, alpha_xm, alpha_xf, alpha_um, alpha_uf,
                    alpha_slack, pred_horzn):
    """Assemble the medium-MPC nonlinear program over the prediction horizon.

    Decision-vector layout (callers unpacking a solution depend on it):
    ``X0`` first, then for every step ``k = 1 .. pred_horzn`` the symbols
    ``U(k-1)``, ``U_f(k-1)``, ``X_f(k)``, ``X(k)``, ``y(k)``, ``e(k)``,
    with sizes ``Nuc_m``, ``Nuc_f``, ``Nx_f``, ``Nx_m``, ``Ny_m`` and 4.

    Constraints appended per step, in this order:
      1. ``I_ode_m(...)['xf'] - X(k) == 0``          (``Nx_m`` rows)
      2. ``const_ies_m(...)``, bounded to 0           (``Nx_f`` rows)
      3. ``y(k) - out_ies_m(...) == 0``               (``Ny_m`` rows)
      4. ``[X_f[7], X_f[8]] + [e0, e2] - [e1, e3]`` bounded by
         ``fast_state_slack_lo`` / ``fast_state_slack_up``

    ``Nx_m``, ``Nx_f``, ``Nuc_m``, ``Nuc_f``, ``Ny_m``, ``I_ode_m``,
    ``const_ies_m`` and ``out_ies_m`` are module-level names not defined in
    this function (presumably provided by the ``Medium_model`` star import).

    Parameters
    ----------
    initial_state
        Value used as lower bound, upper bound and guess for ``X0``, which
        pins the first state.
    state_guess, state_bnds_lo, state_bnds_up
        Guess and bounds applied to every ``X(k)``, ``k >= 1``.
    slow_state_bnds, slow_input_bnds
        Passed unchanged to the model functions as parameters.
    fast_state_guess, fast_state_bnds_lo, fast_state_bnds_up
        Guess and bounds applied to every ``X_f(k)``.
    fast_state_slack_lo, fast_state_slack_up
        Bounds of the slack-relaxed constraint on ``X_f[7]`` and ``X_f[8]``.
    input_guess, input_bnds_lo, input_bnds_up
        Guess and bounds for ``U``; the bounds are multiplied by
        ``input_bin_bnds[1]``.
    fast_input_guess, fast_input_bnds_lo, fast_input_bnds_up
        Guess and bounds for ``U_f``; the bounds are multiplied by
        ``vertcat(input_bin_bnds[0], input_bin_bnds[2])``.
    input_bin_bnds
        Indexed at 0, 1 and 2 to scale the input bounds and also passed
        whole to the model functions (presumably on/off flags -- confirm).
    distb_bnds
        Indexed as ``distb_bnds[k-1, :]``; row ``k-1`` is the disturbance
        passed to the model at step ``k``.
    output_guess
        Guess for every ``y(k)``; outputs are otherwise unbounded.
    output_setpnts
        Indexed as ``output_setpnts[k-1]``; target for ``y(k)[0]``.
    medium_state_setpnts, fast_state_setpnts, medium_input_setpnts,
    fast_input_setpnts
        References subtracted from ``X``, ``X_f``, ``U`` and ``U_f`` to form
        the tracking deviations.
    alpha
        ``alpha[0]`` weights the squared output error, ``alpha[1]`` weights
        the linear term ``U_f[0] + U``.
    alpha_xm, alpha_xf, alpha_um, alpha_uf, alpha_slack
        Weighting matrices of the quadratic deviation and slack terms.
    pred_horzn
        Number of prediction steps.

    Returns
    -------
    tuple
        ``(w, lbw, ubw, w0, g, lbg, ubg, J)``: the ``vertcat``-stacked
        decision variables, their bounds and guess, the stacked constraints
        with their bounds, and the accumulated cost expression.
    """
    # Number of slack variables per step: a (lower, upper) pair for each of
    # the two softened fast states.
    n_slack = 4

    w = []
    w0 = []
    lbw = []
    ubw = []
    J = 0
    g = []
    lbg = []
    ubg = []

    # "Lift" initial state conditions: X0 is a decision variable pinned to
    # initial_state through identical lower and upper bounds.
    Xk = MX.sym('X' + str(0), Nx_m)
    w += [Xk]
    lbw += [initial_state]
    ubw += [initial_state]
    w0 += [initial_state]

    # The binary multipliers and the input bounds scaled by them do not
    # depend on k, so they are built once instead of on every iteration.
    Zu_m = input_bin_bnds[1]
    Zu_f = vertcat(input_bin_bnds[0], input_bin_bnds[2])
    u_m_lo = Zu_m*input_bnds_lo
    u_m_up = Zu_m*input_bnds_up
    u_f_lo = Zu_f*fast_input_bnds_lo
    u_f_up = Zu_f*fast_input_bnds_up

    for k in range(1, pred_horzn+1):
        # Create input variables
        Uk = MX.sym('U' + str(k-1), Nuc_m)
        w += [Uk]
        lbw += [u_m_lo]
        ubw += [u_m_up]
        w0 += [input_guess]

        Ufk = MX.sym('U_f' + str(k-1), Nuc_f)
        w += [Ufk]
        lbw += [u_f_lo]
        ubw += [u_f_up]
        w0 += [fast_input_guess]

        # Create fast states variables
        Xfk = MX.sym('X_f' + str(k), Nx_f)
        w += [Xfk]
        lbw += [fast_state_bnds_lo]
        ubw += [fast_state_bnds_up]
        w0 += [fast_state_guess]

        # Disturbance for this step
        Dk = distb_bnds[k-1,:]

        # Simulate the model starting from the previous step's state
        Ik = I_ode_m(x0 = Xk, p = vertcat(slow_state_bnds, Xfk, slow_input_bnds, Uk, Ufk,
                                          input_bin_bnds, Dk))
        X_int = Ik['xf']

        # Create new states variables (Xk now refers to the state at step k)
        Xk = MX.sym('X' + str(k), Nx_m)
        w += [Xk]
        lbw += [state_bnds_lo]
        ubw += [state_bnds_up]
        w0 += [state_guess]

        # Add dynamic constraints: integrated state must equal the new state
        g += [X_int - Xk]
        lbg += [[0]*Nx_m]
        ubg += [[0]*Nx_m]

        # Constraint from const_ies_m, evaluated at the new state, held at 0
        g += [const_ies_m(Xk, slow_state_bnds, Xfk, slow_input_bnds, Uk, Ufk,
                          input_bin_bnds, Dk)]
        lbg += [[0]*Nx_f]
        ubg += [[0]*Nx_f]

        # Create outputs variables (unbounded; tied to the model below)
        Yk = MX.sym('y' + str(k), Ny_m)
        w += [Yk]
        lbw += [[-np.inf]*Ny_m]
        ubw += [[np.inf]*Ny_m]
        w0 += [output_guess]

        # Add output constraints
        g += [Yk - out_ies_m(Xk, slow_state_bnds, Xfk, slow_input_bnds, Uk, Ufk,
                             input_bin_bnds, Dk)]
        lbg += [[0]*Ny_m]
        ubg += [[0]*Ny_m]

        # Add non-negative slack variables for the softened fast states
        ek = MX.sym('e' + str(k), n_slack)
        w += [ek]
        lbw += [[0]*n_slack]
        ubw += [[np.inf]*n_slack]
        w0 += [[0]*n_slack]

        # ek[0]/ek[2] are added and ek[1]/ek[3] subtracted, so the bounds on
        # X_f[7] and X_f[8] can be violated in either direction at a cost.
        Xfk_slack = vertcat(Xfk[7], Xfk[8])
        ek_lo = vertcat(ek[0], ek[2])
        ek_up = vertcat(ek[1], ek[3])
        g += [Xfk_slack + ek_lo - ek_up]
        lbg += [fast_state_slack_lo]
        ubg += [fast_state_slack_up]

        # Create tracking deviation terms for medium and fast states and inputs
        delta_Xm = Xk - medium_state_setpnts
        delta_Xf = Xfk - fast_state_setpnts

        delta_Um = Uk - medium_input_setpnts
        delta_Uf = Ufk - fast_input_setpnts

        # The cost function
        J += alpha[0] * (Yk[0] - output_setpnts[k-1])**2
        # NOTE(review): this term is scalar only if Uk is scalar
        # (Nuc_m == 1) -- confirm against Medium_model.
        J += alpha[1] * (Ufk[0] + Uk)
        J += mtimes(mtimes(delta_Xm.T, alpha_xm), delta_Xm)
        J += mtimes(mtimes(delta_Xf.T, alpha_xf), delta_Xf)
        J += mtimes(mtimes(delta_Um.T, alpha_um), delta_Um)
        J += mtimes(mtimes(delta_Uf.T, alpha_uf), delta_Uf)
        J += mtimes(mtimes(ek.T, alpha_slack), ek)

    # Concatenate decision variables and constraint terms
    w = vertcat(*w)
    lbw = vertcat(*lbw)
    ubw = vertcat(*ubw)
    w0 = vertcat(*w0)
    g = vertcat(*g)
    lbg = vertcat(*lbg)
    ubg = vertcat(*ubg)

    return w, lbw, ubw, w0, g, lbg, ubg, J

# In[2] Create Medium MPC solver
def solve_opt_m(w, lbw, ubw, w0, g, lbg, ubg, J):
    """Solve the M-MPC nonlinear program with IPOPT and return the optimum.

    Parameters
    ----------
    w, lbw, ubw, w0
        Decision variables with their lower bounds, upper bounds and
        initial guess.
    g, lbg, ubg
        Constraint expressions with their lower and upper bounds.
    J
        Cost expression to minimise.

    Returns
    -------
    The solver's ``'x'`` result converted with ``.full()`` and flattened
    with ``.ravel()``.

    Progress messages and the solver statistics are printed to stdout.
    The solver status is not inspected here, so the returned values are
    whatever IPOPT reported, converged or not.
    """
    print('Creat an M-MPC solver')
    problem = dict(f=J, x=w, g=g)
    solver = nlpsol('nlp_solver', 'ipopt', problem)

    print('Solve M-MPC')
    solution = solver(x0=w0, lbx=lbw, ubx=ubw, lbg=lbg, ubg=ubg)
    print(solver.stats())

    return solution['x'].full().ravel()
